"""
# Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import threading
import time
from multiprocessing.managers import (AcquirerProxy, BaseManager, ListProxy,
                                      Value, ValueProxy)
from queue import Empty, Queue
from typing import Any, List, Tuple

import numpy as np

from fastdeploy.utils import llm_logger


class EngineWorkerQueue:
    """
    Cross-machine and cross-process communication queue between Engine and Worker.
    Manages shared resources using multiprocessing managers for inter-process communication.

    A server instance (``is_server=True``) creates the shared objects -- one set
    per local data-parallel group -- and serves them through a
    ``multiprocessing`` manager. A client instance connects to that manager and
    holds proxies to the objects of its own ``local_data_parallel_id``.

    Most helper methods read attributes such as ``self.lock`` and ``self.tasks``
    that only the client branch of ``__init__`` sets, so they are meant to be
    called on client instances.
    """

    def __init__(
            self,
            address: Tuple[str, int] = ('0.0.0.0', 5000),
            authkey: bytes = b'secret_key',
            is_server: bool = False,
            num_client: int = 1,  # tensor parallel size
            client_id: int = -1,  # tensor parallel id
            local_data_parallel_size: int = 1,  # data parallel size
            local_data_parallel_id: int = 0,  # local data parallel id
    ) -> None:
        """
        Initialize the communication queue.

        Args:
            address: Network address (IP, port) for the queue server
            authkey: Authentication key for secure connection
            is_server: Whether this instance acts as a server
            num_client: Total number of expected clients (tensor parallel size)
            client_id: Unique identifier for client instances; must lie in
                ``[0, num_client)`` when ``is_server`` is False
            local_data_parallel_size: Number of local data-parallel groups;
                the server creates one set of shared objects per group
            local_data_parallel_id: Data-parallel group whose shared objects
                this client uses

        Raises:
            AssertionError: If a client's ``client_id`` is out of range, or the
                server was started with a different ``num_client``.
            ConnectionError: If a client cannot reach the server.
        """
        self.address: Tuple[str, int] = address
        self.authkey: bytes = authkey
        self.is_server: bool = is_server
        self.num_client: int = num_client
        self.client_id: int = client_id
        self.local_data_parallel_size = local_data_parallel_size
        self.local_data_parallel_id = local_data_parallel_id

        class QueueManager(BaseManager):
            """
            Custom QueueManager for proxy object registration.
            """
            pass

        if is_server:
            # Server-side initialization for shared resources. Each list holds
            # one entry per local data-parallel group; the registered accessors
            # below select an entry by their ``idx`` argument.
            self.tasks_init: List[List[Any]] = [
                list() for _ in range(self.local_data_parallel_size)
            ]
            # Read flags start at 1 ("already read") so the first put succeeds.
            self.client_read_flag_init: List[List[int]] = [
                [1] * self.num_client
                for _ in range(self.local_data_parallel_size)
            ]
            self.lock_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]
            self.read_finish_flag_init: List[Value] = [
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            self.connected_client_counter_init: List[Value] = [
                Value("i", 0) for _ in range(self.local_data_parallel_size)
            ]
            self.finished_req_queue = [
                Queue() for _ in range(self.local_data_parallel_size)
            ]
            self.cache_infos_init: List[List[Any]] = [
                list() for _ in range(self.local_data_parallel_size)
            ]
            self.client_read_info_flag_init: List[List[int]] = [
                [1] * self.num_client
                for _ in range(self.local_data_parallel_size)
            ]
            self.lock_info_init: List[threading.Lock] = [
                threading.Lock() for _ in range(self.local_data_parallel_size)
            ]

            self.finish_request_barrier = [
                threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
            ]

            # Register shared objects with proxy types
            QueueManager.register("get_tasks",
                                  callable=lambda idx: self.tasks_init[idx],
                                  proxytype=ListProxy)
            QueueManager.register(
                "get_client_read_flag",
                callable=lambda idx: self.client_read_flag_init[idx],
                proxytype=ListProxy)
            QueueManager.register("get_lock",
                                  callable=lambda idx: self.lock_init[idx],
                                  proxytype=AcquirerProxy)
            QueueManager.register(
                "get_read_finish_flag",
                callable=lambda idx: self.read_finish_flag_init[idx],
                proxytype=ValueProxy)
            QueueManager.register(
                "get_connected_client_counter",
                callable=lambda idx: self.connected_client_counter_init[idx],
                proxytype=ValueProxy)

            QueueManager.register(
                'get_finish_request_queue',
                callable=lambda idx: self.finished_req_queue[idx])

            QueueManager.register(
                "get_cache_infos",
                callable=lambda idx: self.cache_infos_init[idx],
                proxytype=ListProxy)

            QueueManager.register(
                "get_client_read_info_flag",
                callable=lambda idx: self.client_read_info_flag_init[idx],
                proxytype=ListProxy)
            QueueManager.register(
                "get_lock_info",
                callable=lambda idx: self.lock_info_init[idx],
                proxytype=AcquirerProxy)

            self.disaggregate_requests = [
                Queue() for _ in range(self.local_data_parallel_size)
            ]
            QueueManager.register(
                "get_disaggregate_requests",
                callable=lambda idx: self.disaggregate_requests[idx])

            # A single queue shared by all data-parallel groups.
            self.available_prefill_instances = Queue()
            QueueManager.register(
                "get_available_prefill_instances",
                callable=lambda: self.available_prefill_instances)

            QueueManager.register(
                "get_finish_request_barrier",
                callable=lambda idx: self.finish_request_barrier[idx])
            self.manager: BaseManager = QueueManager(address=self.address,
                                                     authkey=self.authkey)
            self.manager.start()
            llm_logger.info("EngineWorkerQueue server started.")
        else:
            # Client-side connection setup
            assert self.client_id >= 0 and self.client_id < self.num_client, (
                f"self.client_id={self.client_id}, self.num_client={self.num_client}"
            )
            # Clients register the same accessor names without callables, so
            # each call is forwarded to the object the server registered.
            # Because no proxytype is given here, the returned proxies are
            # AutoProxy objects exposing the methods the server advertises.
            for typeid in ("get_tasks", "get_client_read_flag", "get_lock",
                           "get_read_finish_flag",
                           "get_connected_client_counter",
                           "get_finish_request_queue", "get_cache_infos",
                           "get_client_read_info_flag", "get_lock_info",
                           "get_disaggregate_requests",
                           "get_available_prefill_instances",
                           "get_finish_request_barrier"):
                QueueManager.register(typeid)
            self.manager = QueueManager(address=self.address,
                                        authkey=self.authkey)
            self._connect_with_retry()

            # Get proxy objects for shared resources
            self.tasks: ListProxy = self.manager.get_tasks(
                self.local_data_parallel_id)
            self.client_read_flag: ListProxy = self.manager.get_client_read_flag(
                self.local_data_parallel_id)
            self.lock: AcquirerProxy = self.manager.get_lock(
                self.local_data_parallel_id)
            self.read_finish_flag: ValueProxy = self.manager.get_read_finish_flag(
                self.local_data_parallel_id)
            self.connected_client_counter: ValueProxy = \
                self.manager.get_connected_client_counter(self.local_data_parallel_id)
            self.cache_infos: ListProxy = self.manager.get_cache_infos(
                self.local_data_parallel_id)
            self.client_read_info_flag: ListProxy = self.manager.get_client_read_info_flag(
                self.local_data_parallel_id)
            self.lock_info: AcquirerProxy = self.manager.get_lock_info(
                self.local_data_parallel_id)

            # Prefill/decode (P/D) disaggregation resources
            self.disaggregate_requests = self.manager.get_disaggregate_requests(
                self.local_data_parallel_id)
            self.available_prefill_instances = self.manager.get_available_prefill_instances()
            self.finish_request_barrier = self.manager.get_finish_request_barrier(
                self.local_data_parallel_id
            )
            self.finished_req_queue = self.manager.get_finish_request_queue(
                self.local_data_parallel_id)
            # The server sized the flag list with its own num_client.
            assert self.num_client == len(self.client_read_flag)

            # Update client connection counter. The client-side lock proxy only
            # exposes acquire/release (no context-manager methods), so use an
            # explicit try/finally to guarantee the release.
            self.lock.acquire()
            try:
                connected = self.connected_client_counter.get() + 1
                self.connected_client_counter.set(connected)
            finally:
                self.lock.release()
            llm_logger.info((
                f"Connected EngineWorkerQueue client_id: {self.client_id}, number "
                f"of connected clients: {connected}"
            ))

    def _connect_with_retry(self,
                            max_retries: int = 5,
                            interval: float = 3) -> None:
        """
        Connect to the server with retry mechanism.

        Args:
            max_retries: Maximum connection attempts
            interval: Seconds to wait between consecutive attempts

        Raises:
            ConnectionError: If all connection attempts fail; the last
                ``ConnectionRefusedError`` is chained as its cause.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                self.manager.connect()
                return
            except ConnectionRefusedError as err:
                last_error = err
                # No point sleeping after the final failed attempt.
                if attempt < max_retries - 1:
                    time.sleep(interval)
        raise ConnectionError(
            f"TaskQueue cannot connect {self.address}") from last_error

    def put_tasks(self, tasks: List[Any]) -> None:
        """
        Add tasks to the shared queue in a thread-safe manner.
        Waits until all clients have read previous tasks before adding new ones.

        The previous contents are discarded and the whole ``tasks`` object is
        stored as a single element of the shared list; all read flags are then
        reset to 0.

        Args:
            tasks: Tasks to be added to the queue
        """
        while True:
            self.lock.acquire()
            try:
                # ``[:]`` fetches the flags in one round trip instead of one
                # remote call per element when iterating a list proxy.
                if sum(self.client_read_flag[:]) >= self.num_client:
                    self.tasks[:] = list()
                    self.client_read_flag[:] = [0] * self.num_client
                    self.tasks.append(tasks)
                    return
            finally:
                self.lock.release()
            # Some client has not read the previous batch yet; poll again.
            time.sleep(0.001)

    def get_tasks(self) -> Tuple[List[Any], bool]:
        """
        Retrieve tasks from the shared queue and mark them as read by this client.

        The shared list is cleared once every client has read it.

        Returns:
            tuple: (snapshot of the shared task list, bool indicating if all
            clients have read)
        """
        self.lock.acquire()
        try:
            tasks: List[Any] = self.tasks[:]
            self.client_read_flag[self.client_id] = 1
            all_client_read: bool = (
                sum(self.client_read_flag[:]) == self.num_client)
            if all_client_read:
                self.tasks[:] = list()
        finally:
            self.lock.release()
        return tasks, all_client_read

    def num_tasks(self) -> int:
        """
        Get current number of tasks in the queue.

        Returns:
            int: Number of elements in the shared task list
        """
        self.lock.acquire()
        try:
            return len(self.tasks)
        finally:
            self.lock.release()

    def get_prefill_instances(self):
        """
        Pop one entry from the available-prefill-instances queue without blocking.

        Returns:
            The next queued entry, or 0 if the queue is currently empty.
        """
        # get_nowait avoids the check-then-get race of qsize() followed by a
        # blocking get() when another consumer drains the queue in between.
        try:
            return self.available_prefill_instances.get_nowait()
        except Empty:
            return 0

    def put_cache_info(self, cache_info) -> None:
        """
        Publish a new batch of cache infos to all clients.

        Waits (polling every millisecond) until every client has read the
        previous batch, then replaces the shared list with the entries of
        ``cache_info`` and resets all read flags to 0.

        Args:
            cache_info: Entries to publish; they are added individually via
                ``extend``
        """
        while True:
            self.lock_info.acquire()
            try:
                if sum(self.client_read_info_flag[:]) >= self.num_client:
                    self.cache_infos[:] = list()
                    self.client_read_info_flag[:] = [0] * self.num_client

                    self.cache_infos.extend(cache_info)
                    llm_logger.debug(
                        f"cache_infos: {self.cache_infos}  local_data_parallel_id:{self.local_data_parallel_id}"
                    )
                    return
            finally:
                self.lock_info.release()
            time.sleep(0.001)

    def get_cache_info(self) -> List[Any]:
        """
        Retrieve the current cache-info batch and mark it as read by this client.

        The shared list is cleared once every client has read it.

        Returns:
            list: Snapshot of the cache infos, or an empty list if this client
            has already read the current batch.
        """
        self.lock_info.acquire()
        try:
            if self.client_read_info_flag[self.client_id] == 1:
                return []
            cache_infos: List[Any] = self.cache_infos[:]
            self.client_read_info_flag[self.client_id] = 1
            if sum(self.client_read_info_flag[:]) == self.num_client:
                self.cache_infos[:] = list()
        finally:
            self.lock_info.release()
        if len(cache_infos) != 0:
            llm_logger.debug(
                f"get cache infos: {cache_infos}  local_data_parallel_id:{self.local_data_parallel_id}"
            )
        return cache_infos

    def num_cache_infos(self) -> int:
        """
        Get current number of cache-info entries in the shared list.

        Returns:
            int: Number of cache-info entries
        """
        self.lock_info.acquire()
        try:
            return len(self.cache_infos)
        finally:
            self.lock_info.release()

    def put_finished_req(self, req_ids) -> None:
        """
        Put finished request ID into the queue.

        Args:
            req_ids: Finished request identifier(s); stored as one queue item
        """
        self.finished_req_queue.put(req_ids)

    def get_finished_req(self) -> Any:
        """
        Get the next finished-request entry from the queue without blocking.

        Returns:
            The object passed to an earlier ``put_finished_req`` call, or an
            empty list if the queue is currently empty.
        """
        try:
            ans = self.finished_req_queue.get_nowait()
        except Empty:
            return []
        llm_logger.debug(f"get finished req: {ans}")
        return ans

    def disaggregate_queue_empty(self):
        """
        Check if the disaggregated task queue is empty.
        """
        return self.disaggregate_requests.qsize() == 0

    def put_disaggregated_tasks(self, item):
        """
        put disaggregated tasks to the queue
        """
        llm_logger.debug("put item to queue")
        self.disaggregate_requests.put(item)
        llm_logger.debug("put item to queue success")

    def get_disaggregated_tasks(self):
        """
        Drain all currently queued disaggregated tasks without blocking.

        Returns:
            None if the queue is empty on entry, otherwise a list of the
            drained items in FIFO order.
        """
        llm_logger.debug("get tasks from queue")
        if self.disaggregate_requests.qsize() == 0:
            return None
        item = []
        # get_nowait stops cleanly if another consumer empties the queue
        # between checks, instead of blocking forever in get().
        while True:
            try:
                item.append(self.disaggregate_requests.get_nowait())
            except Empty:
                break
        llm_logger.debug("get tasks from queue success")
        return item

    def cleanup(self):
        """
        Exit the worker queue gracefully.

        Shuts down the manager process on server instances; a no-op for
        clients or when initialization did not get as far as creating the
        manager.
        """
        manager = getattr(self, "manager", None)
        if manager is not None and self.is_server:
            manager.shutdown()
